<!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.01 Transitional//EN"
                "http://www.w3.org/TR/html4/loose.dtd">
<html>
<head>
  <title>Description of mdngrad</title>
  <meta name="keywords" content="mdngrad">
  <meta name="description" content="MDNGRAD Evaluate gradient of error function for Mixture Density Network.">
  <meta http-equiv="Content-Type" content="text/html; charset=iso-8859-1">
  <meta name="generator" content="m2html &copy; 2003 Guillaume Flandin">
  <meta name="robots" content="index, follow">
  <link type="text/css" rel="stylesheet" href="../../m2html.css">
</head>
<body>
<a name="_top"></a>


<h1>mdngrad</h1>

<h2><a name="_name"></a>PURPOSE <a href="#_top"><img alt="^" border="0" src="../../up.png"></a></h2>
<div class="box"><strong>MDNGRAD Evaluate gradient of error function for Mixture Density Network.</strong></div>

<h2><a name="_synopsis"></a>SYNOPSIS <a href="#_top"><img alt="^" border="0" src="../../up.png"></a></h2>
<div class="box"><strong>function g = mdngrad(net, x, t) </strong></div>

<h2><a name="_description"></a>DESCRIPTION <a href="#_top"><img alt="^" border="0" src="../../up.png"></a></h2>
<div class="fragment"><pre class="comment">MDNGRAD Evaluate gradient of error function for Mixture Density Network.

    Description
    G = MDNGRAD(NET, X, T) takes a mixture density network data
    structure NET, a matrix X of input vectors and a matrix T of target
    vectors, and evaluates the gradient G of the error function with
    respect to the network weights. The error function is the negative
    log likelihood of the target data. Each row of X corresponds to one
    input vector and each row of T corresponds to one target vector.

    See also
    <a href="mdn.html" class="code" title="function net = mdn(nin, nhidden, ncentres, dim_target, mix_type,prior, beta)">MDN</a>, <a href="mdnfwd.html" class="code" title="function [mixparams, y, z, a] = mdnfwd(net, x)">MDNFWD</a>, <a href="mdnerr.html" class="code" title="function e = mdnerr(net, x, t)">MDNERR</a>, <a href="mdnprob.html" class="code" title="function [prob,a] = mdnprob(mixparams, t)">MDNPROB</a>, <a href="mlpbkp.html" class="code" title="function g = mlpbkp(net, x, z, deltas)">MLPBKP</a></pre></div>
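<p>The description above can be checked numerically. The following is a minimal sketch, not part of the Netlab distribution. It assumes that Netlab is on the MATLAB path and that <code>mdnpak</code> and <code>mdnunpak</code> (not documented on this page) pack and unpack the network weights as a row vector. It compares G with a central finite-difference estimate of the gradient of MDNERR. The network sizes, the toy data and the step size <code>epsilon</code> are arbitrary illustrative choices.</p>
<div class="fragment"><pre>% Small spherical MDN: 1 input, 4 hidden units, 2 centres, 1-D target
net = mdn(1, 4, 2, 1, 'spherical');

% Toy data: one input vector per row of x, one target vector per row of t
x = linspace(0, 1, 10)';
t = x + 0.3*sin(2*pi*x);

% Analytic gradient of the negative log likelihood
g = mdngrad(net, x, t);

% Central finite-difference estimate, one weight at a time
w = mdnpak(net);          % assumed helper: weights as a row vector
epsilon = 1e-6;           % illustrative step size
gnum = zeros(size(w));
for i = 1:length(w)
  wplus = w;   wplus(i) = w(i) + epsilon;
  wminus = w;  wminus(i) = w(i) - epsilon;
  eplus = mdnerr(mdnunpak(net, wplus), x, t);
  eminus = mdnerr(mdnunpak(net, wminus), x, t);
  gnum(i) = (eplus - eminus)/(2*epsilon);
end

% Largest discrepancy between the analytic and numerical gradients
maxdiff = max(abs(g - gnum));
disp(maxdiff);</pre></div>
<p>If the two gradients agree, <code>maxdiff</code> should be small relative to the size of the entries of <code>g</code>. A large value would suggest a mismatch between MDNGRAD and MDNERR, or a step size that is poorly suited to the data.</p>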

<!-- crossreference -->
<h2><a name="_cross"></a>CROSS-REFERENCE INFORMATION <a href="#_top"><img alt="^" border="0" src="../../up.png"></a></h2>
This function calls:
<ul style="list-style-image:url(../../matlabicon.gif)">
<li><a href="consist.html" class="code" title="function errstring = consist(model, type, inputs, outputs)">consist</a> CONSIST Check that arguments are consistent.</li>
<li><a href="dist2.html" class="code" title="function n2 = dist2(x, c)">dist2</a> DIST2 Calculates squared distance between two sets of points.</li>
<li><a href="mdndist2.html" class="code" title="function n2 = mdndist2(mixparams, t)">mdndist2</a> MDNDIST2 Calculates squared distance between centres of Gaussian kernels and data</li>
<li><a href="mdnfwd.html" class="code" title="function [mixparams, y, z, a] = mdnfwd(net, x)">mdnfwd</a> MDNFWD Forward propagation through Mixture Density Network.</li>
<li><a href="mdnpost.html" class="code" title="function [post, a] = mdnpost(mixparams, t)">mdnpost</a> MDNPOST Computes the posterior probability for each MDN mixture component.</li>
<li><a href="mlpbkp.html" class="code" title="function g = mlpbkp(net, x, z, deltas)">mlpbkp</a> MLPBKP Backpropagate gradient of error function for 2-layer network.</li></ul>
This function is called by:
<ul style="list-style-image:url(../../matlabicon.gif)">
</ul>
<!-- crossreference -->


<h2><a name="_source"></a>SOURCE CODE <a href="#_top"><img alt="^" border="0" src="../../up.png"></a></h2>
<div class="fragment"><pre>0001 <a name="_sub0" href="#_subfunctions" class="code">function g = mdngrad(net, x, t)</a>
0002 <span class="comment">%MDNGRAD Evaluate gradient of error function for Mixture Density Network.</span>
0003 <span class="comment">%</span>
0004 <span class="comment">%    Description</span>
0005 <span class="comment">%    G = MDNGRAD(NET, X, T) takes a mixture density network data</span>
0006 <span class="comment">%    structure NET, a matrix X of input vectors and a matrix T of target</span>
0007 <span class="comment">%    vectors, and evaluates the gradient G of the error function with</span>
0008 <span class="comment">%    respect to the network weights. The error function is the negative</span>
0009 <span class="comment">%    log likelihood of the target data. Each row of X corresponds to one</span>
0010 <span class="comment">%    input vector and each row of T corresponds to one target vector.</span>
0011 <span class="comment">%</span>
0012 <span class="comment">%    See also</span>
0013 <span class="comment">%    MDN, MDNFWD, MDNERR, MDNPROB, MLPBKP</span>
0014 <span class="comment">%</span>
0015 
0016 <span class="comment">%    Copyright (c) Ian T Nabney (1996-2001)</span>
0017 <span class="comment">%    David J Evans (1998)</span>
0018 
0019 <span class="comment">% Check arguments for consistency</span>
0020 errstring = <a href="consist.html" class="code" title="function errstring = consist(model, type, inputs, outputs)">consist</a>(net, <span class="string">'mdn'</span>, x, t);
0021 <span class="keyword">if</span> ~isempty(errstring)
0022   error(errstring);
0023 <span class="keyword">end</span>
0024 
0025 [mixparams, y, z] = <a href="mdnfwd.html" class="code" title="function [mixparams, y, z, a] = mdnfwd(net, x)">mdnfwd</a>(net, x);
0026 
0027 <span class="comment">% Compute gradients at MLP outputs: put the answer in deltas</span>
0028 ncentres = net.mdnmixes.ncentres;
0029 dim_target = net.mdnmixes.dim_target;
0030 nmixparams = net.mdnmixes.nparams;
0031 ntarget = size(t, 1);
0032 deltas = zeros(ntarget, net.mlp.nout);
0033 e = ones(ncentres, 1);
0034 f = ones(1, dim_target);
0035 
0036 post = <a href="mdnpost.html" class="code" title="function [post, a] = mdnpost(mixparams, t)">mdnpost</a>(mixparams, t);
0037 
0038 <span class="comment">% Calculate derivatives for the mixing coefficients (priors)</span>
0039 deltas(:,1:ncentres)  = mixparams.mixcoeffs - post;
0040 
0041 <span class="comment">% Calculate centre derivatives</span>
0042 long_t = kron(ones(1, ncentres), t);
0043 centre_err = mixparams.centres - long_t;
0044 
0045 <span class="comment">% Replicate the posteriors to match each centre component u_jk:</span>
0046 <span class="comment">% the resulting array is (ntarget, (ncentres*dim_target))</span>
0047 long_post = kron(ones(dim_target, 1), post);
0048 long_post = reshape(long_post, ntarget, (ncentres*dim_target));
0049 
0050 <span class="comment">% Replicate the variances to match each centre component u_jk:</span>
0051 var = mixparams.covars;
0052 var = kron(ones(dim_target, 1), var);
0053 var = reshape(var, ntarget, (ncentres*dim_target));
0054 
0055 <span class="comment">% Compute centre deltas</span>
0056 deltas(:, (ncentres+1):(ncentres*(1+dim_target))) = <span class="keyword">...</span>
0057                        (centre_err.*long_post)./var;
0058 
0059 <span class="comment">% Compute variance deltas</span>
0060 dist2             = <a href="mdndist2.html" class="code" title="function n2 = mdndist2(mixparams, t)">mdndist2</a>(mixparams, t);
0061 c                 = dim_target*ones(ntarget, ncentres);
0062 deltas(:, (ncentres*(1+dim_target)+1):nmixparams) = <span class="keyword">...</span>
0063                       post.*((dist2./mixparams.covars)-c)./(-2);
0064 
0065 <span class="comment">% Now back-propagate deltas through MLP</span>
0066 g = <a href="mlpbkp.html" class="code" title="function g = mlpbkp(net, x, z, deltas)">mlpbkp</a>(net.mlp, x, z, deltas);</pre></div>
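<p>The centre and variance deltas in the listing depend on expanding (ntarget, ncentres) arrays to (ntarget, ncentres*dim_target) so that they line up column by column with <code>mixparams.centres</code>. The following self-contained sketch reproduces the <code>kron</code> and <code>reshape</code> steps on small made-up arrays and does not need Netlab. The values of <code>t</code> and <code>post</code> and the names <code>alt_t</code>, <code>alt_post</code>, <code>same_t</code> and <code>same_post</code> are illustrative assumptions, not part of the original code.</p>
<div class="fragment"><pre>% Illustrative sizes and values (not taken from any real network)
ntarget = 2; ncentres = 3; dim_target = 2;
t = [1 2; 3 4];                       % (ntarget, dim_target)
post = [0.2 0.3 0.5; 0.6 0.1 0.3];    % (ntarget, ncentres)

% Targets tiled once per centre, as in the listing
long_t = kron(ones(1, ncentres), t);

% Posteriors repeated dim_target times per centre, as in the listing
long_post = kron(ones(dim_target, 1), post);
long_post = reshape(long_post, ntarget, ncentres*dim_target);

% Single-call forms giving the same layout, shown only for comparison
alt_t = repmat(t, 1, ncentres);
alt_post = kron(post, ones(1, dim_target));

% Compare the two constructions
same_t = isequal(long_t, alt_t);
same_post = isequal(long_post, alt_post);

disp(long_t);
disp(long_post);
disp([same_t same_post]);</pre></div>
<p>In this layout, column (j-1)*dim_target + k of <code>long_post</code> holds <code>post(:, j)</code>, so each posterior is paired with every coordinate of its own centre. The same expansion is applied to <code>mixparams.covars</code> in the listing.</p>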
<hr><address>Generated on Tue 26-Sep-2006 10:36:21 by <strong><a href="http://www.artefact.tk/software/matlab/m2html/">m2html</a></strong> &copy; 2003</address>
</body>
</html>